import torch
from torch import nn

from evaluator import Evaluator
from network import get_model, get_resnet18
from trainer import Trainer
from util import try_gpu

# Training hyper-parameters read by main(): batch_size is shared by both
# Evaluators and the Trainer; lr and num_epoch are forwarded to the Trainer.
batch_size: int = 128
lr: float = 5e-2
num_epoch: int = 300
# Passed to get_model(wid_mul=...); presumably a network-width multiplier — confirm.
wid_mul: float = 1 / 4

def main():
    """Train the configured model end to end.

    Builds the network with ``get_model`` using the module-level ``wid_mul``
    (attention disabled), moves it to the device returned by ``try_gpu()``,
    creates a training and a test ``Evaluator`` that share ``batch_size``,
    and runs ``Trainer.train()`` with the module-level ``lr``, ``num_epoch``
    and ``batch_size``.
    """
    net = get_model(wid_mul=wid_mul, attention=False)
    net.to(device=try_gpu())

    # Built in the same order as before: training split first, then test.
    train_eval, test_eval = (
        Evaluator(batch_size, is_train=flag) for flag in (True, False)
    )

    Trainer(
        net,
        train_evaluator=train_eval,
        test_evaluator=test_eval,
        lr=lr,
        num_epoch=num_epoch,
        batch_size=batch_size,
    ).train()

if __name__ == "__main__":
    # Start training only when run as a script, not when this module is imported.
    main()